{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 案例：给用户推荐电影"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import scipy.io as sio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['__header__', '__version__', '__globals__', 'Y', 'R'])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "mat = sio.loadmat('data/ex8_movies.mat')\n",
    "mat.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1682, 943), (1682, 943))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y,R = mat['Y'],mat['R']\n",
    "Y.shape,R.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['__header__', '__version__', '__globals__', 'X', 'Theta', 'num_users', 'num_movies', 'num_features'])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "param_mat =sio.loadmat('data/ex8_movieParams.mat')\n",
    "param_mat.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "X,Theta,nu,nm,nf = param_mat['X'],param_mat['Theta'],param_mat['num_users'],param_mat['num_movies'],param_mat['num_features']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((1682, 10),\n",
       " (943, 10),\n",
       " array([[943]], dtype=uint16),\n",
       " array([[1682]], dtype=uint16),\n",
       " array([[10]], dtype=uint8))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape,Theta.shape,nu,nm,nf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(943, 1682, 10)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nu = int(nu)\n",
    "nm = int(nm)\n",
    "nf = int(nf)\n",
    "nu,nm,nf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. 序列化参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def serialize(X,Theta):\n",
    "    \n",
    "    return np.append(X.flatten(),Theta.flatten())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 解序列化参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def deserialize(params,nm,nu,nf):\n",
    "    X = params[:nm*nf].reshape(nm,nf)\n",
    "    Theta = params[nm*nf:].reshape(nu,nf)\n",
    "    return X,Theta"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. 代价函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def costFunction(params,Y,R,nm,nu,nf,lamda):\n",
    "    X,Theta = deserialize(params,nm,nu,nf)\n",
    "    error = 0.5 * np.square((X@Theta.T - Y)* R).sum()\n",
    "    reg1 = 0.5 * lamda * np.square(X).sum()\n",
    "    reg2 = 0.5 * lamda * np.square(Theta).sum()\n",
    "    return error + reg1 + reg2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "22.224603725685675"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "users = 4\n",
    "movies = 5\n",
    "features = 3\n",
    "X_sub = X[:movies,:features]\n",
    "Theta_sub = Theta[:users,:features]\n",
    "Y_sub = Y[:movies,:users]\n",
    "R_sub = R[:movies,:users]\n",
    "cost1 = costFunction(serialize(X_sub,Theta_sub),Y_sub,R_sub,movies,users,features,lamda = 0)\n",
    "cost1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "31.344056244274217"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cost2 = costFunction(serialize(X_sub,Theta_sub),Y_sub,R_sub,movies,users,features,lamda = 1.5)\n",
    "cost2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.梯度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
